# Launch script for the Janus SFT run: prepares the output directory, archives
# a copy of this script there, then starts rig_train.py through torchrun.
#
# Overridable from the environment:
#   GPUS_PER_NODE        value for torchrun --nproc-per-node (default 8)
#   MIRCO_BATCH_SIZE     multiplier used for --pack-max-length (default 16)
#   ACCUMULATIVE_COUNTS  multiplier used for --global-batch-size (default 1)
# (The "MIRCO" spelling is kept on purpose: it is the name the environment
# variable and the rig_train.py flag already use.)
set -x

# NOTE(review): HOME and ENV_PATH are blank and look like per-machine
# placeholders to fill in before running -- confirm. With ENV_PATH empty the
# launcher path below resolves to /bin/torchrun.
export HOME=''
export ENV_PATH=""
export TRITON_CACHE_DIR="/tmp/triton"

OUTPUT_DIR='work_dirs/janus_sft'
# mkdir -p succeeds silently when the directory already exists, so no
# separate existence check is needed.
mkdir -p "$OUTPUT_DIR"

# Keep a copy of the exact script used next to the run's outputs.
SCRIPT_NAME=$(basename "$0")
cp "$0" "${OUTPUT_DIR}/${SCRIPT_NAME}"

GPUS_PER_NODE=${GPUS_PER_NODE:-8}
export PYTHONPATH="$(pwd):$(pwd)/../"

MIRCO_BATCH_SIZE=${MIRCO_BATCH_SIZE:-16}
ACCUMULATIVE_COUNTS=${ACCUMULATIVE_COUNTS:-1}

# -m debugpy --connect 5680
MAX_LENGHT=4096

# Edit before running:
#   --model     Model here
#   --datasets  Data here
#
# Do NOT put anything (not even a comment or a space) after a trailing
# backslash below: the backslash must be the last character on the line or
# the command is cut short and the remaining flags are executed as separate
# commands.
#
# NOTE: the script's exit status is that of tee, not torchrun, because the
# training output is piped through tee.
HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 "${ENV_PATH}/bin/torchrun" \
  --nproc-per-node="${GPUS_PER_NODE}" \
  rig_train.py \
  --model Janus-1.3B \
  --freeze-style mode1 \
  --datasets demo_data/.json \
  --num-workers 4 \
  --lr 1e-3 \
  --wd 0.05 \
  --warmup-ratio 0.03 \
  --work-dir "${OUTPUT_DIR}" \
  --log-interval 1 \
  --seed 42 \
  --checkpoint-interval 5000 \
  --hf-interval 400 \
  --dset-pack \
  --dset-cache-dir ../janus_sft_no_und \
  --mirco-batch-size 1 \
  --global-batch-size "$((GPUS_PER_NODE * ACCUMULATIVE_COUNTS))" \
  --max-length "${MAX_LENGHT}" \
  --pack-max-length "$((MIRCO_BATCH_SIZE * MAX_LENGHT))" \
  --concat-before-pack \
  --group-by-length \
  --resume \
  --max-keep-ckpts 5 \
  --pack-len-type 'total_block' \
  --pack-extra-buffer-size 1000 \
  --gradient-sync-after-accumulate \
  2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
